

<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <meta charset="utf-8" />
  
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  
  <title>mindspore.nn.transformer &mdash; MindSpore master documentation</title>
  

  
  <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
  <link rel="stylesheet" href="../_static/pygments.css" type="text/css" />

  
  

  
  

  

  
  <!--[if lt IE 9]>
    <script src="../_static/js/html5shiv.min.js"></script>
  <![endif]-->
  
    
      <script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
        <script src="../_static/jquery.js"></script>
        <script src="../_static/underscore.js"></script>
        <script src="../_static/doctools.js"></script>
        <script src="../_static/language_data.js"></script>
        <script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
    
    <script type="text/javascript" src="../_static/js/theme.js"></script>

    
    <link rel="index" title="Index" href="../genindex.html" />
    <link rel="search" title="Search" href="../search.html" />
    <link rel="next" title="mindspore.numpy" href="mindspore.numpy.html" />
    <link rel="prev" title="mindspore.nn.probability.distribution.Uniform" href="nn_probability/mindspore.nn.probability.distribution.Uniform.html" /> 
</head>

<body class="wy-body-for-nav">

   
  <div class="wy-grid-for-nav">
    
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search" >
          

          
            <a href="../index.html" class="icon icon-home"> MindSpore
          

          
          </a>

          
            
            
          

          
<div role="search">
  <form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" />
    <input type="hidden" name="check_keywords" value="yes" />
    <input type="hidden" name="area" value="default" />
  </form>
</div>

          
        </div>

        
        <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
          
            
            
              
            
            
              <p class="caption"><span class="caption-text">MindSpore Python API</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="mindspore.html">mindspore</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.common.initializer.html">mindspore.common.initializer</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.communication.html">mindspore.communication</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.compression.html">mindspore.compression</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.context.html">mindspore.context</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.dataset.html">mindspore.dataset</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.dataset.audio.html">mindspore.dataset.audio</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.dataset.config.html">mindspore.dataset.config</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.dataset.text.html">mindspore.dataset.text</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.dataset.transforms.html">mindspore.dataset.transforms</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.dataset.vision.html">mindspore.dataset.vision</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.mindrecord.html">mindspore.mindrecord</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.nn.html">mindspore.nn</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.nn.probability.html">mindspore.nn.probability</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">mindspore.nn.transformer</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.numpy.html">mindspore.numpy</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.ops.html">mindspore.ops</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.parallel.html">mindspore.parallel</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.parallel.nn.html">mindspore.parallel.nn</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.profiler.html">mindspore.profiler</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.scipy.html">mindspore.scipy</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.train.html">mindspore.train</a></li>
<li class="toctree-l1"><a class="reference internal" href="mindspore.boost.html">mindspore.boost</a></li>
</ul>
<p class="caption"><span class="caption-text">MindSpore C++ API</span></p>
<ul>
<li class="toctree-l1"><a class="reference external" href="https://www.mindspore.cn/lite/api/zh-CN/master/api_cpp/mindspore.html">MindSpore Lite↗</a></li>
</ul>

            
          
        </div>
        
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">

      
      <nav class="wy-nav-top" aria-label="top navigation">
        
          <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
          <a href="../index.html">MindSpore</a>
        
      </nav>


      <div class="wy-nav-content">
        
        <div class="rst-content">
        
          

















<div role="navigation" aria-label="breadcrumbs navigation">

  <ul class="wy-breadcrumbs">
    
      <li><a href="../index.html" class="icon icon-home"></a> &raquo;</li>
        
      <li>mindspore.nn.transformer</li>
    
    
      <li class="wy-breadcrumbs-aside">
        
          
            <a href="../_sources/api_python/mindspore.nn.transformer.rst.txt" rel="nofollow"> View page source</a>
          
        
      </li>
    
  </ul>

  
  <hr/>
</div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
            
  <div class="section" id="mindspore-nn-transformer">
<h1>mindspore.nn.transformer<a class="headerlink" href="#mindspore-nn-transformer" title="Permalink to this headline">¶</a></h1>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Transformer网络。这些是实验性接口，可能会修改或删除。</p>
</div>
<dl class="class">
<dt id="mindspore.nn.transformer.AttentionMask">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">AttentionMask</code><span class="sig-paren">(</span><em class="sig-param">seq_length</em>, <em class="sig-param">parallel_config=default_dpmp_config</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.AttentionMask" title="Permalink to this definition">¶</a></dt>
<dd><p>从输入掩码中获取下三角矩阵。输入掩码是值为1或0的二维Tensor (batch_size, seq_length)。1表示当前位置是一个有效的标记，其他值则表示当前位置不是一个有效的标记。</p>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>seq_length</strong> (int) - 表示输入Tensor的序列长度。</p></li>
<li><p><strong>parallel_config</strong> (OpParallelConfig) - 表示并行配置。默认值为 <cite>default_dpmp_config</cite> ，表示一个带有默认参数的 <cite>OpParallelConfig</cite> 实例。</p></li>
</ul>
<p><strong>输入：</strong></p>
<ul class="simple">
<li><p><strong>input_mask</strong> (Tensor) - 掩码矩阵，shape为(batch_size, seq_length)，表示每个位置是否为有效输入。</p></li>
</ul>
<p><strong>输出：</strong></p>
<p>Tensor，表示shape为(batch_size, seq_length, seq_length)的注意力掩码矩阵。</p>
<p><strong>异常：</strong></p>
<ul class="simple">
<li><p><strong>TypeError</strong> - <cite>seq_length</cite> 不是整数。</p></li>
<li><p><strong>ValueError</strong> - <cite>seq_length</cite> 不是正数。</p></li>
<li><p><strong>TypeError</strong> - <cite>parallel_config</cite> 不是OpParallelConfig的子类。</p></li>
</ul>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">AttentionMask</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">Tensor</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">mask</span> <span class="o">=</span> <span class="n">AttentionMask</span><span class="p">(</span><span class="n">seq_length</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">mask_array</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">]],</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">inputs</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">mask_array</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">res</span> <span class="o">=</span> <span class="n">mask</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">res</span><span class="p">)</span>
<span class="go">[[[1. 0. 0. 0.]</span>
<span class="go">  [1. 1. 0. 0.]</span>
<span class="go">  [1. 1. 1. 0.]</span>
<span class="go">  [0. 0. 0. 0.]]]</span>
</pre></div>
</div>
</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.VocabEmbedding">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">VocabEmbedding</code><span class="sig-paren">(</span><em class="sig-param">vocab_size</em>, <em class="sig-param">embedding_size</em>, <em class="sig-param">parallel_config=default_embedding_parallel_config</em>, <em class="sig-param">param_init=&quot;normal&quot;</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.VocabEmbedding" title="Permalink to this definition">¶</a></dt>
<dd><p>根据输入的索引查找参数表中的行作为返回值。当设置并行模式为 <cite>AUTO_PARALLEL_MODE</cite> 时，如果parallel_config.vocab_emb_dp为True时，那么embedding lookup表采用数据并行的方式，数据并行度为 <cite>parallel_config.data_parallel</cite> ，否则按 <cite>parallel_config.model_parallel</cite> 对embedding表中的第0维度进行切分。</p>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>vocab_size</strong> (int) - 表示查找表的大小。</p></li>
<li><p><strong>embedding_size</strong> (int) - 表示查找表中每个嵌入向量的大小。</p></li>
<li><p><strong>param_init</strong> (Union[Tensor, str, Initializer, numbers.Number]) - 表示embedding_table的Initializer。当指定字符串时，请参见 <cite>initializer</cite> 类了解字符串的值。默认值：'normal'。</p></li>
<li><p><strong>parallel_config</strong> (EmbeddingOpParallelConfig) - 表示网络的并行配置。默认值为 <cite>default_embedding_parallel_config</cite> ，表示带有默认参数的 <cite>EmbeddingOpParallelConfig</cite> 实例。</p></li>
</ul>
<p><strong>输入：</strong></p>
<ul class="simple">
<li><p><strong>input_ids</strong> (Tensor) - shape为(batch_size, seq_length)的输入，其数据类型为int32。</p></li>
</ul>
<p><strong>输出：</strong></p>
<p>Tuple，表示一个包含(<cite>output</cite>, <cite>embedding_table</cite>)的元组。</p>
<ul class="simple">
<li><p><strong>output</strong> (Tensor) - shape为(batch_size, seq_length, embedding_size)的嵌入向量查找结果。</p></li>
<li><p><strong>weight</strong> (Tensor) - shape为(vocab_size, embedding_size)的嵌入表。</p></li>
</ul>
<p><strong>异常：</strong></p>
<ul class="simple">
<li><p><strong>ValueError</strong> - parallel_config.vocab_emb_dp为True时，词典的大小不是parallel_config.model_parallel的倍数。</p></li>
<li><p><strong>ValueError</strong> - <cite>vocab_size</cite> 不是正值。</p></li>
<li><p><strong>ValueError</strong> - <cite>embedding_size</cite> 不是正值。</p></li>
<li><p><strong>TypeError</strong> - <cite>parallel_config</cite> 不是OpParallelConfig的子类。</p></li>
</ul>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">VocabEmbedding</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">Tensor</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">dtype</span> <span class="k">as</span> <span class="n">mstype</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">VocabEmbedding</span><span class="p">(</span><span class="n">vocab_size</span><span class="o">=</span><span class="mi">30</span><span class="p">,</span> <span class="n">embedding_size</span><span class="o">=</span><span class="mi">30</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">tensor</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">20</span><span class="p">,</span> <span class="mi">15</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">output</span><span class="p">,</span> <span class="n">table</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">tensor</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(20, 15, 30)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">table</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(30, 30)</span>
</pre></div>
</div>
</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.MultiHeadAttention">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">MultiHeadAttention</code><span class="sig-paren">(</span><em class="sig-param">batch_size</em>, <em class="sig-param">src_seq_length</em>, <em class="sig-param">tgt_seq_length</em>, <em class="sig-param">hidden_size</em>, <em class="sig-param">num_heads</em>, <em class="sig-param">hidden_dropout_rate=0.1</em>, <em class="sig-param">attention_dropout_rate=0.1</em>, <em class="sig-param">compute_dtype=mstype.float16</em>, <em class="sig-param">softmax_compute_type=mstype.float32</em>, <em class="sig-param">param_init_type=mstype.float32</em>, <em class="sig-param">use_past=False</em>, <em class="sig-param">parallel_config=default_dpmp_config</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.MultiHeadAttention" title="Permalink to this definition">¶</a></dt>
<dd><p>论文 <a class="reference external" href="https://arxiv.org/pdf/1706.03762v5.pdf">Attention Is All You Need</a> 中所述的多头注意力的实现。给定src_seq_length长度的query向量，tgt_seq_length长度的key向量和value，注意力计算流程如下：</p>
<p>其中， <cite>head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)</cite> 。注意：输出层的投影计算中带有偏置参数。</p>
<p>如果query tensor、key tensor和value tensor相同，则上述即为自注意力机制的计算过程。</p>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>batch_size</strong> (int) - 表示训练批次的大小。</p></li>
<li><p><strong>src_seq_length</strong> (int) - 表示query向量的序列长度。</p></li>
<li><p><strong>tgt_seq_length</strong> (int) - 表示key向量和value向量的序列长度。</p></li>
<li><p><strong>hidden_size</strong> (int) - 表示输入的向量大小。</p></li>
<li><p><strong>num_heads</strong> (int) - 表示注意力机制中头的数量。</p></li>
<li><p><strong>hidden_dropout_rate</strong> (float) - 表示最后dense输出的丢弃率。默认值：0.1。</p></li>
<li><p><strong>attention_dropout_rate</strong> (float) - 表示注意力score的丢弃率。默认值：0.1。</p></li>
<li><p><strong>compute_dtype</strong> (dtype.Number) - 表示dense中矩阵乘法的计算类型。默认值：dtype.float16。其值应为dtype.float32或dtype.float16。</p></li>
<li><p><strong>param_init_type</strong> (dtype.Number) - 表示模块的参数初始化类型。默认值：dtype.float32。其值应为dtype.float32或dtype.float16。</p></li>
<li><p><strong>softmax_compute_type</strong> (dtype.Number) - 表示softmax计算模块的类型。默认值：dtype.float32。其值应为dtype.float32或dtype.float16。</p></li>
<li><p><strong>use_past</strong> (bool) - 使用过去状态进行计算，用于增量预测。例如，如果我们有两个单词，想生成十个或以上单词。我们只需要计算一次这两个单词的状态，然后逐个生成下一个单词。当use_past为True时，有两个步骤可以执行预测。
第一步是通过 <cite>model.add_flags_recursive(is_first_iteration=True)</cite> 将is_first_iteration设为True，并传递完整的输入。然后，通过 <cite>model.add_flags_recursive(is_first_iteration=False)</cite> 将is_first_iteration设为False。此时，传递step的输入tensor，并对其进行循环。默认值：False</p></li>
<li><p><strong>parallel_config</strong> (OpParallelConfig) - 表示并行配置。默认值为 <cite>default_dpmp_config</cite> ，表示一个带有参数的 <cite>OpParallelConfig</cite> 实例。</p></li>
</ul>
<p><strong>输入：</strong></p>
<ul class="simple">
<li><p><strong>query_tensor</strong> (Tensor) - use_past为False或is_first_iteration为True时，表示shape为(batch_size, src_seq_length, hidden_size)或(batch_size * src_seq_length, hidden_size)的query向量。否则，shape必须为(batch_size, 1, hidden_size)。</p></li>
<li><p><strong>key_tensor</strong> (Tensor) - use_past为False或is_first_iteration为True时，表示shape为(batch_size, tgt_seq_length, hidden_size)或(batch_size * tgt_seq_length, hidden_size)的key向量。否则，shape必须为(batch_size, 1, hidden_size)。</p></li>
<li><p><strong>value_tensor</strong> (Tensor) - use_past为False或is_first_iteration为True时，表示shape为(batch_size, tgt_seq_length, hidden_size)或(batch_size * tgt_seq_length, hidden_size)的value向量。否则，shape必须为(batch_size, 1, hidden_size)。</p></li>
<li><p><strong>attention_mask</strong> (Tensor) - use_past为False或is_first_iteration为True时，表示shape为(batch_size, src_seq_length, tgt_seq_length)的注意力掩码矩阵。否则，shape必须为(batch_size, 1, tgt_seq_length)。</p></li>
<li><p><strong>key_past</strong> (Tensor) - shape为(batch_size, num_heads, size_per_head, tgt_seq_length)的Float16 tensor，表示过去所计算的key向量。
当use_past为True时，需要传入非None值用于增量预测。默认值为None。</p></li>
<li><p><strong>value_past</strong> (Tensor) - shape为(batch_size, num_heads, tgt_seq_length, size_per_head)的Float16 tensor，表示过去所计算的value向量。
当use_past为True时，需要传入非None值用于增量预测。默认值为None。</p></li>
<li><p><strong>batch_valid_length</strong> (Tensor) - shape为(batch_size,)的Int32 tensor，表示已经计算的token索引。
当use_past为True时，需要传入非None值用于增量预测。默认值为None。</p></li>
</ul>
<p><strong>输出：</strong></p>
<p>Tuple，表示一个包含(<cite>output</cite>, <cite>layer_present</cite>)的元组。</p>
<ul class="simple">
<li><p><strong>output</strong> (Tensor) - Tensor。use_past为False或is_first_iteration为True时，表示shape为(batch_size, src_seq_length, hidden_size)或(batch_size * src_seq_length, hidden_size)的层输出的float tensor。否则，shape将为(batch_size, 1, hidden_size)。</p></li>
<li><p><strong>layer_present</strong> (Tuple) - 表示shape为((batch_size, num_heads, size_per_head, tgt_seq_length)或(batch_size, num_heads, tgt_seq_length, size_per_head))的投影key向量和value向量的Tensor的元组。</p></li>
</ul>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">MultiHeadAttention</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">dtype</span> <span class="k">as</span> <span class="n">mstype</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">Tensor</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">hidden_size</span><span class="o">=</span><span class="mi">15</span><span class="p">,</span> <span class="n">src_seq_length</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">tgt_seq_length</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span>
<span class="gp">... </span>                           <span class="n">num_heads</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">from_tensor</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">15</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">to_tensor</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">15</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">attention_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">20</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">attn_out</span><span class="p">,</span> <span class="n">past</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">from_tensor</span><span class="p">,</span> <span class="n">to_tensor</span><span class="p">,</span> <span class="n">to_tensor</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">attn_out</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 20, 15)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 3, 5, 20)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 3, 20, 5)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># When use use_past=True, it includes two steps to implement the incremental prediction.</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># Step 1: set is_first_iteration=True, and input the full sequence length&#39;s state.</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># We need to prepare the memory parameters for saving key and value states firstly.</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">hidden_size</span><span class="o">=</span><span class="mi">15</span><span class="p">,</span> <span class="n">src_seq_length</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">tgt_seq_length</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span>
<span class="gp">... </span>                           <span class="n">num_heads</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">use_past</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">key_past</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">20</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">value_past</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">5</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">batch_valid_length</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># Set is_first_iteration=True to generate the full memory states</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span><span class="o">.</span><span class="n">add_flags_recursive</span><span class="p">(</span><span class="n">is_first_iteration</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">attn_out</span><span class="p">,</span> <span class="n">past</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">from_tensor</span><span class="p">,</span> <span class="n">to_tensor</span><span class="p">,</span> <span class="n">to_tensor</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">key_past</span><span class="p">,</span> <span class="n">value_past</span><span class="p">,</span>
<span class="gp">... </span>                       <span class="n">batch_valid_length</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">attn_out</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 20, 15)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 3, 5, 20)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 3, 20, 5)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">from_tensor</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">15</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">to_tensor</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">15</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">attention_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">20</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># Step 2: set is_first_iteration=False, and pass the single word to run the prediction rather than the full</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># sequence.</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span><span class="o">.</span><span class="n">add_flags_recursive</span><span class="p">(</span><span class="n">is_first_iteration</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">attn_out</span><span class="p">,</span> <span class="n">past</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">from_tensor</span><span class="p">,</span> <span class="n">to_tensor</span><span class="p">,</span> <span class="n">to_tensor</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">key_past</span><span class="p">,</span> <span class="n">value_past</span><span class="p">,</span>
<span class="gp">... </span>                       <span class="n">batch_valid_length</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">attn_out</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 1, 15)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 3, 5, 20)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 3, 20, 5)</span>
</pre></div>
</div>
</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.FeedForward">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">FeedForward</code><span class="sig-paren">(</span><em class="sig-param">hidden_size</em>, <em class="sig-param">ffn_hidden_size</em>, <em class="sig-param">dropout_rate</em>, <em class="sig-param">hidden_act=&quot;gelu&quot;</em>, <em class="sig-param">expert_num=1</em>, <em class="sig-param">param_init_type=mstype.float32</em>, <em class="sig-param">parallel_config=default_dpmp_config</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.FeedForward" title="Permalink to this definition">¶</a></dt>
<dd><p>具有两层线性层的多层感知器，并在最终输出上使用Dropout。第一层前馈层将输入维度从hidden_size投影到ffn_hidden_size，并在中间应用激活层。第二个线性层将该维度从ffn_hidden_size投影回hidden_size。配置parallel_config之后，
第一个前馈层的权重将在输入维度上被分片，第二个线性层在输出维度上进行切分。总体过程如下：</p>
<div class="math notranslate nohighlight">\[\text{output} = \text{Dropout}((\text{activation}(x W_1 + b_1)) W_2 + b_2)\]</div>
<p>其中 <span class="math notranslate nohighlight">\(W_1, W_2, b_1\)</span> 和 <span class="math notranslate nohighlight">\(b_2\)</span> 为可训练参数。</p>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>hidden_size</strong> (int) - 表示输入的维度。</p></li>
<li><p><strong>ffn_hidden_size</strong> (int) - 表示中间隐藏大小。</p></li>
<li><p><strong>dropout_rate</strong> (float) - 表示第二个线性输出的丢弃率。</p></li>
<li><p><strong>hidden_act</strong> (str) - 表示第一层前馈层的激活。其值可为’relu’、’relu6’、’tanh’、’gelu’、’fast_gelu’、’elu’、’sigmoid’、’prelu’、’leakyrelu’、’hswish’、’hsigmoid’、’logsigmoid’等等。默认值：gelu。</p></li>
<li><p><strong>expert_num</strong> (int) - 表示线性中使用的专家数量。对于expert_num &gt; 1用例，使用BatchMatMul。BatchMatMul中的第一个维度表示expert_num。默认值：1</p></li>
<li><p><strong>param_init_type</strong> (dtype.Number) - 表示参数初始化类型。其值应为dtype.float32或dtype.float16。默认值：dtype.float32</p></li>
<li><p><strong>parallel_config</strong> (OpParallelConfig) - 表示并行配置。更多详情，请参见 <cite>OpParallelConfig</cite> 。默认值为 <cite>default_dpmp_config</cite> ，表示一个带有默认参数的 <cite>OpParallelConfig</cite> 实例。</p></li>
</ul>
<p><strong>输入：</strong></p>
<ul class="simple">
<li><p><strong>x</strong> (Tensor) - 应为 <cite>[batch, seq_length, hidden_size]或[batch * seq_length, hidden_size]</cite> 。表示浮点Tensor。</p></li>
</ul>
<p><strong>输出：</strong></p>
<p>Tensor，表示映射后该层的输出。shape为 <cite>[batch, seq_length, hidden_size]</cite> 或 <cite>[batch * seq_length, hidden_size]</cite> 。</p>
<p><strong>异常：</strong></p>
<ul class="simple">
<li><p><strong>ValueError</strong> - <cite>hidden_act</cite> 不是字符串。</p></li>
<li><p><strong>TypeError</strong> - <cite>parallel_config</cite> 不是OpParallelConfig的子类。</p></li>
<li><p><strong>ValueError</strong> - <cite>ffn_hidden_size</cite> 不是parallel_config中model_parallel的倍数。</p></li>
<li><p><strong>ValueError</strong> - <cite>hidden_size</cite> 不是parallel_config中model_parallel的倍数。</p></li>
</ul>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">FeedForward</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">dtype</span> <span class="k">as</span> <span class="n">mstype</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">Tensor</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">FeedForward</span><span class="p">(</span><span class="n">hidden_size</span><span class="o">=</span><span class="mi">15</span><span class="p">,</span> <span class="n">ffn_hidden_size</span><span class="o">=</span><span class="mi">30</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">tensor</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">15</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">tensor</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 20, 15)</span>
</pre></div>
</div>
</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.TransformerEncoder">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">TransformerEncoder</code><span class="sig-paren">(</span><em class="sig-param">batch_size</em>, <em class="sig-param">num_layers</em>, <em class="sig-param">hidden_size</em>, <em class="sig-param">ffn_hidden_size</em>, <em class="sig-param">seq_length</em>, <em class="sig-param">num_heads</em>, <em class="sig-param">attention_dropout_rate=0.1</em>, <em class="sig-param">hidden_dropout_rate=0.1</em>, <em class="sig-param">hidden_act=&quot;gelu&quot;</em>, <em class="sig-param">post_layernorm_residual=False</em>, <em class="sig-param">layernorm_compute_type=mstype.float32</em>, <em class="sig-param">softmax_compute_type=mstype.float32</em>, <em class="sig-param">param_init_type=mstype.float32</em>, <em class="sig-param">lambda_func=None</em>, <em class="sig-param">offset=0</em>, <em class="sig-param">use_past=False</em>, <em class="sig-param">moe_config=default_moe_config</em>, <em class="sig-param">parallel_config=default_transformer_config</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.TransformerEncoder" title="Permalink to this definition">¶</a></dt>
<dd><p>Transformer中的编码器模块，具有多层堆叠的 <cite>TransformerEncoderLayer</cite> ，包括多头自注意力层和前馈层。</p>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>batch_size</strong> (int) - 表示输入tensor的批次大小。</p></li>
<li><p><strong>num_layers</strong> (int) - 表示 <cite>TransformerEncoderLayer</cite> 的层数。</p></li>
<li><p><strong>hidden_size</strong> (int) - 表示输入的隐藏大小。</p></li>
<li><p><strong>ffn_hidden_size</strong> (int) - 表示前馈层中bottleneck的隐藏大小。</p></li>
<li><p><strong>seq_length</strong> (int) - 表示输入序列长度。</p></li>
<li><p><strong>num_heads</strong> (int) - 表示注意力头的数量。</p></li>
<li><p><strong>hidden_dropout_rate</strong> (float) - 表示作用在隐藏层输出的丢弃率。默认值：0.1</p></li>
<li><p><strong>attention_dropout_rate</strong> (float) - 表示注意力score的丢弃率。默认值：0.1</p></li>
<li><p><strong>post_layernorm_residual</strong> (bool) - 表示是否在LayerNorm之前使用残差，即是否选择残差为Post-LayerNorm或者Pre-LayerNorm。默认值：False</p></li>
<li><p><strong>hidden_act</strong> (str) - 表示内部前馈层的激活函数。其值可为’relu’、’relu6’、’tanh’、’gelu’、’fast_gelu’、’elu’、’sigmoid’、’prelu’、’leakyrelu’、’hswish’、’hsigmoid’、’logsigmoid’等等。默认值：gelu。</p></li>
<li><p><strong>layernorm_compute_type</strong> (dtype.Number) - 表示LayerNorm的计算类型。其值应为dtype.float32或dtype.float16。默认值为dtype.float32。</p></li>
<li><p><strong>softmax_compute_type</strong> (dtype.Number) - 表示注意力中softmax的计算类型。其值应为dtype.float32或dtype.float16。默认值为dtype.float32。</p></li>
<li><p><strong>param_init_type</strong> (dtype.Number) - 表示模块的参数初始化类型。其值应为dtype.float32或dtype.float16。默认值为dtype.float32。</p></li>
<li><p><strong>use_past</strong> (bool) - 使用过去状态进行计算，用于增量预测。例如，如果我们有两个单词，想生成十个或以上单词。我们只需要计算一次这两个单词的状态，然后逐个生成下一个单词。当use_past为True时，有两个步骤可以运行预测。第一步是通过 <cite>model.add_flags_recursive(is_first_iteration=True)</cite> 将is_first_iteration设为True，并传递完整的输入。然后，通过 <cite>model.add_flags_recursive(is_first_iteration=False)</cite> 将is_first_iteration设为False。此时，传递step的输入tensor，并对其进行环回。默认值：False</p></li>
<li><p><strong>lambda_func</strong> - 表示设置融合索引、pipeline阶段和重计算属性的函数。如果用户想确定pipeline阶段和梯度聚合融合，用户可以传递一个接受 <cite>network</cite> 、 <cite>layer_id</cite> 、 <cite>offset</cite> 、 <cite>parallel_config</cite> 和 <cite>layers</cite> 的函数。 <cite>network(Cell)</cite> 表示transformer块， <cite>layer_id(int)</cite> 表示当前模块的层索引，从零开始计数， <cite>offset(int)</cite> 表示如果网络中还有其他模块，则layer_index需要一个偏置。pipeline的默认设置为： <cite>(layer_id + offset) // (layers / pipeline_stage)</cite> 。</p></li>
<li><p><strong>offset</strong> (int) - 表示 <cite>decoder</cite> 的初始层索引。其用于设置梯度聚合的融合值和流水线并行的stage值。</p></li>
<li><p><strong>moe_config</strong> (MoEConfig) - 表示MoE (Mixture of Expert)的配置。</p></li>
<li><p><strong>parallel_config</strong> (TransformerOpParallelConfig) - 表示并行配置。默认值为 <cite>default_transformer_config</cite> ，表示带有默认参数的 <cite>TransformerOpParallelConfig</cite> 实例。</p></li>
</ul>
<p><strong>输入：</strong></p>
<ul class="simple">
<li><p><strong>hidden_states</strong> (Tensor) - Tensor。如果use_past为False或者is_first_iteration为True，shape为[batch_size, seq_length, hidden_size]或者[batch_size * seq_length, hidden_size]。否则，shape应为[batch_size, 1, hidden_size]。</p></li>
<li><p><strong>attention_mask</strong> (Tensor) - Tensor，表示shape为[batch_size, seq_length, seq_length]的注意力掩码。</p></li>
<li><p><strong>init_reset</strong> (Tensor) - shape为[1]的bool tensor，用于清除增量预测中使用的past key参数和past value参数。仅当use_past为True时有效。默认值为True。</p></li>
<li><p><strong>batch_valid_length</strong> (Tensor) - shape为[batch_size]的Int32 tensor，表示过去所计算的索引。当use_past为True时，它用于增量预测。默认值为None。</p></li>
</ul>
<p><strong>输出：</strong></p>
<p>Tuple，表示一个包含(<cite>output</cite>, <cite>layer_present</cite>)的元组。</p>
<ul class="simple">
<li><p><strong>output</strong> (Tensor) - use_past为False或is_first_iteration为True时，表示shape为(batch_size, seq_length, hidden_size)或(batch_size * seq_length, hidden_size)的层输出的float tensor。否则，shape将为(batch_size, 1, hidden_size)。</p></li>
<li><p><strong>layer_present</strong> (Tuple) - 大小为num_layers的元组，其中每个元组都包含shape为((batch_size, num_heads, size_per_head, seq_length)或(batch_size, num_heads, seq_length, size_per_head))的投影key向量和value向量的Tensor。</p></li>
</ul>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">dtype</span> <span class="k">as</span> <span class="n">mstype</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">TransformerEncoder</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">Tensor</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">TransformerEncoder</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">hidden_size</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">ffn_hidden_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">seq_length</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span>
<span class="gp">... </span>                           <span class="n">num_heads</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">encoder_input_value</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">8</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">encoder_input_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">16</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">output</span><span class="p">,</span> <span class="n">past</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">encoder_input_value</span><span class="p">,</span> <span class="n">encoder_input_mask</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 16, 8)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">past</span><span class="p">))</span>
<span class="go">2</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 4, 16)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 16, 4)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># When use use_past=True, it includes two steps to implement the incremental prediction.</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># Step 1: set is_first_iteration=True, and input the full sequence length&#39;s state.</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">batch_valid_length</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">init_reset</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="kc">True</span><span class="p">],</span> <span class="n">mstype</span><span class="o">.</span><span class="n">bool_</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># Set is_first_iteration=True to generate the full memory states</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">TransformerEncoder</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">hidden_size</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">ffn_hidden_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">seq_length</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span>
<span class="gp">... </span>                           <span class="n">num_heads</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">use_past</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span><span class="o">.</span><span class="n">add_flags_recursive</span><span class="p">(</span><span class="n">is_first_iteration</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">hidden</span><span class="p">,</span> <span class="n">past</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">encoder_input_value</span><span class="p">,</span> <span class="n">encoder_input_mask</span><span class="p">,</span> <span class="n">init_reset</span><span class="p">,</span> <span class="n">batch_valid_length</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">hidden</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 16, 8)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 4, 16)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 16, 4)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">encoder_input_value</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">8</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">encoder_input_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">16</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">init_reset</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="kc">False</span><span class="p">],</span> <span class="n">mstype</span><span class="o">.</span><span class="n">bool_</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># Step 2: set is_first_iteration=False, and pass the single word to run the prediction rather than the full</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># sequence.</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span><span class="o">.</span><span class="n">add_flags_recursive</span><span class="p">(</span><span class="n">is_first_iteration</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">hidden</span><span class="p">,</span> <span class="n">past</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">encoder_input_value</span><span class="p">,</span> <span class="n">encoder_input_mask</span><span class="p">,</span> <span class="n">init_reset</span><span class="p">,</span> <span class="n">batch_valid_length</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">hidden</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 1, 8)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 4, 16)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 16, 4)</span>
</pre></div>
</div>
</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.TransformerDecoder">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">TransformerDecoder</code><span class="sig-paren">(</span><em class="sig-param">num_layers</em>, <em class="sig-param">batch_size</em>, <em class="sig-param">hidden_size</em>, <em class="sig-param">ffn_hidden_size</em>, <em class="sig-param">src_seq_length</em>, <em class="sig-param">tgt_seq_length</em>, <em class="sig-param">num_heads</em>, <em class="sig-param">attention_dropout_rate=0.1</em>, <em class="sig-param">hidden_dropout_rate=0.1</em>, <em class="sig-param">post_layernorm_residual=False</em>, <em class="sig-param">layernorm_compute_type=mstype.float32</em>, <em class="sig-param">softmax_compute_type=mstype.float32</em>, <em class="sig-param">param_init_type=mstype.float32</em>, <em class="sig-param">hidden_act=&quot;gelu&quot;</em>, <em class="sig-param">lambda_func=None</em>, <em class="sig-param">use_past=False</em>, <em class="sig-param">offset=0</em>, <em class="sig-param">moe_config=default_moe_config</em>, <em class="sig-param">parallel_config=default_transformer_config</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.TransformerDecoder" title="Permalink to this definition">¶</a></dt>
<dd><p>Transformer中的解码器模块，为多层堆叠的 <cite>TransformerDecoderLayer</cite> ，包括多头自注意力层、交叉注意力层和前馈层。</p>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>batch_size</strong> (int) - 表示输入Tensor的批次大小。</p></li>
<li><p><strong>num_layers</strong> (int) - 表示 <cite>TransformerDecoderLayer</cite> 的层数。</p></li>
<li><p><strong>hidden_size</strong> (int) - 表示输入的隐藏大小。</p></li>
<li><p><strong>ffn_hidden_size</strong> (int) - 表示前馈层中bottleneck的隐藏大小。</p></li>
<li><p><strong>src_seq_length</strong> (int) - 表示输入源序列长度。</p></li>
<li><p><strong>tgt_seq_length</strong> (int) - 表示输入目标序列长度。</p></li>
<li><p><strong>num_heads</strong> (int) - 表示注意力头的数量。</p></li>
<li><p><strong>hidden_dropout_rate</strong> (float) - 表示作用在隐藏层输出的丢弃率。默认值：0.1</p></li>
<li><p><strong>attention_dropout_rate</strong> (float) - 表示注意力score的丢弃率。默认值：0.1</p></li>
<li><p><strong>post_layernorm_residual</strong> (bool) - 表示是否在LayerNorm之前使用残差，即是否选择残差为Post-LayerNorm或者Pre-LayerNorm。默认值：False</p></li>
<li><p><strong>hidden_act</strong> (str) - 表示内部前馈层的激活函数。其值可为’relu’、’relu6’、’tanh’、’gelu’、’fast_gelu’、’elu’、’sigmoid’、’prelu’、’leakyrelu’、’hswish’、’hsigmoid’、’logsigmoid’等等。默认值：gelu。</p></li>
<li><p><strong>layernorm_compute_type</strong> (dtype.Number) - 表示LayerNorm的计算类型。其值应为dtype.float32或dtype.float16。默认值为dtype.float32。</p></li>
<li><p><strong>softmax_compute_type</strong> (dtype.Number) - 表示注意力中softmax的计算类型。其值应为dtype.float32或dtype.float16。默认值为dtype.float32。</p></li>
<li><p><strong>param_init_type</strong> (dtype.Number) - 表示模块的参数初始化类型。其值应为dtype.float32或dtype.float16。默认值为dtype.float32。</p></li>
<li><p><strong>offset</strong> (int) - 表示 <cite>decoder</cite> 的初始层索引偏移值。其用于设置梯度聚合的融合值和流水线并行的stage值，使其不与编码器层的相关属性重叠。</p></li>
<li><p><strong>lambda_func</strong> - 表示确定梯度融合索引、pipeline阶段和重计算属性的函数。如果用户想确定pipeline阶段和梯度聚合融合，用户可以传递一个接受 <cite>network</cite> 、 <cite>layer_id</cite> 、 <cite>offset</cite> 、 <cite>parallel_config</cite> 和 <cite>layers</cite> 的函数。 <cite>network(Cell)</cite> 表示transformer块， <cite>layer_id(int)</cite> 表示当前模块的层索引，从零开始计数， <cite>offset(int)</cite> 表示如果网络中还有其他模块，则layer_index需要一个偏置。pipeline的默认设置为： <cite>(layer_id + offset) // (layers / pipeline_stage)</cite> 。默认值：None</p></li>
<li><p><strong>moe_config</strong> (MoEConfig) - 表示MoE (Mixture of Expert)的配置。</p></li>
<li><p><strong>parallel_config</strong> (TransformerOpParallelConfig) - 表示并行配置。默认值为 <cite>default_transformer_config</cite> ，表示带有默认参数的 <cite>TransformerOpParallelConfig</cite> 实例。</p></li>
</ul>
<p><strong>输入：</strong></p>
<ul class="simple">
<li><p><strong>hidden_stats</strong> (Tensor) - shape为[batch_size, seq_length, hidden_size]或[batch_size * seq_length, hidden_size]的输入tensor。</p></li>
<li><p><strong>attention_mask</strong> (Tensor) - shape为[batch_size, seq_length, seq_length]的解码器的注意力掩码。</p></li>
<li><p><strong>encoder_output</strong> (Tensor) - shape为[batch_size, seq_length, hidden_size]或[batch_size * seq_length, hidden_size]的编码器的输出。</p>
</li>
<li><p><strong>memory_mask</strong> (Tensor) - shape为[batch, tgt_seq_length, src_seq_length]的交叉注意力的memory掩码，其中tgt_seq_length表示解码器的长度。注：当网络位于最外层时，此参数不能通过None传递。默认值为None。</p></li>
<li><p><strong>init_reset</strong> (Tensor) - shape为[1]的bool tensor，用于清除增量预测中使用的past key参数和past value参数。仅当use_past为True时有效。默认值为True。</p></li>
<li><p><strong>batch_valid_length</strong> (Tensor) - shape为[batch_size]的Int32 tensor，表示过去所计算的索引。当use_past为True时，它用于增量预测。默认值为None。</p></li>
</ul>
<p><strong>输出：</strong></p>
<p>Tuple，表示一个包含(<cite>output</cite>, <cite>layer_present</cite>)的元组。</p>
<ul class="simple">
<li><p><strong>output</strong> (Tensor) - 输出的logit。shape为[batch, tgt_seq_length, hidden_size]或[batch * tgt_seq_length, hidden_size]。</p></li>
<li><p><strong>layer_present</strong> (Tuple) - 大小为层数的元组，其中每个元组都是shape为(batch_size, num_heads, size_per_head, tgt_seq_length)或(batch_size, num_heads, tgt_seq_length, size_per_head)的自注意力中的投影key向量和value向量的tensor，或者是shape为(batch_size, num_heads, size_per_head, src_seq_length)或(batch_size, num_heads, src_seq_length, size_per_head)的交叉注意力中的投影key向量和value向量的tensor。</p></li>
</ul>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">dtype</span> <span class="k">as</span> <span class="n">mstype</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">TransformerDecoder</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">Tensor</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">TransformerDecoder</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">hidden_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">ffn_hidden_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span>
<span class="gp">... </span>                           <span class="n">num_heads</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">src_seq_length</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">tgt_seq_length</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">encoder_input_value</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">64</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">decoder_input_value</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">64</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">decoder_input_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">memory_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">output</span><span class="p">,</span> <span class="n">past</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">decoder_input_value</span><span class="p">,</span> <span class="n">decoder_input_mask</span><span class="p">,</span> <span class="n">encoder_input_value</span><span class="p">,</span> <span class="n">memory_mask</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 10, 64)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">past</span><span class="p">))</span>
<span class="go">1</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 32, 10)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 10, 32)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 32, 20)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">3</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 20, 32)</span>
</pre></div>
</div>
</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.TransformerEncoderLayer">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">TransformerEncoderLayer</code><span class="sig-paren">(</span><em class="sig-param">batch_size</em>, <em class="sig-param">hidden_size</em>, <em class="sig-param">ffn_hidden_size</em>, <em class="sig-param">num_heads</em>, <em class="sig-param">seq_length</em>, <em class="sig-param">attention_dropout_rate=0.1</em>, <em class="sig-param">hidden_dropout_rate=0.1</em>, <em class="sig-param">post_layernorm_residual=False</em>, <em class="sig-param">layernorm_compute_type=mstype.float32</em>, <em class="sig-param">softmax_compute_type=mstype.float32</em>, <em class="sig-param">param_init_type=mstype.float32</em>, <em class="sig-param">hidden_act=&quot;gelu&quot;</em>, <em class="sig-param">use_past=False</em>, <em class="sig-param">moe_config=default_moe_config</em>, <em class="sig-param">parallel_config=default_dpmp_config</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.TransformerEncoderLayer" title="Permalink to this definition">¶</a></dt>
<dd><p>Transformer的编码器层。Transformer的编码器层上的单层的实现，包括多头注意力层和前馈层。</p>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>batch_size</strong> (int) - 表示输入Tensor的批次大小。</p></li>
<li><p><strong>hidden_size</strong> (int) - 表示输入的隐藏大小。</p></li>
<li><p><strong>seq_length</strong> (int) - 表示输入序列长度。</p></li>
<li><p><strong>ffn_hidden_size</strong> (int) - 表示前馈层中bottleneck的隐藏大小。</p></li>
<li><p><strong>num_heads</strong> (int) - 表示注意力头的数量。</p></li>
<li><p><strong>hidden_dropout_rate</strong> (float) - 表示作用在隐藏层输出的丢弃率。默认值：0.1</p></li>
<li><p><strong>attention_dropout_rate</strong> (float) - 表示注意力score的丢弃率。默认值：0.1</p></li>
<li><p><strong>post_layernorm_residual</strong> (bool) - 表示是否在LayerNorm之前使用残差，即是否选择残差为Post-LayerNorm或者Pre-LayerNorm。默认值：False</p></li>
<li><p><strong>hidden_act</strong> (str) - 表示内部前馈层的激活函数。其值可为’relu’、’relu6’、’tanh’、’gelu’、’fast_gelu’、’elu’、’sigmoid’、’prelu’、’leakyrelu’、’hswish’、’hsigmoid’、’logsigmoid’等等。默认值：gelu。</p></li>
<li><p><strong>layernorm_compute_type</strong> (dtype.Number) - 表示LayerNorm的计算类型。其值应为dtype.float32或dtype.float16。默认值为dtype.float32。</p></li>
<li><p><strong>softmax_compute_type</strong> (dtype.Number) - 表示注意力中softmax的计算类型。其值应为dtype.float32或dtype.float16。默认值为mstype.float32。</p></li>
<li><p><strong>param_init_type</strong> (dtype.Number) - 表示模块的参数初始化类型。其值应为dtype.float32或dtype.float16。默认值为dtype.float32。</p></li>
<li><p><strong>use_past</strong> (bool) - 使用过去状态进行计算，用于增量预测。例如，如果我们有两个单词，想生成十个或以上单词。我们只需要计算一次这两个单词的状态，然后逐个生成下一个单词。当use_past为True时，有两个步骤可以运行预测。第一步是通过 <cite>model.add_flags_recursive(is_first_iteration=True)</cite> 将is_first_iteration设为True，并传递完整的输入。然后，通过 <cite>model.add_flags_recursive(is_first_iteration=False)</cite> 将is_first_iteration设为False。此时，传递step的输入tensor，并对其进行环回。默认值：False</p></li>
<li><p><strong>moe_config</strong> (MoEConfig) - 表示MoE (Mixture of Expert)的配置。</p></li>
<li><p><strong>parallel_config</strong> (OpParallelConfig) - 表示并行配置。默认值为 <cite>default_dpmp_config</cite> ，表示一个带有默认参数的 <cite>OpParallelConfig</cite> 实例。</p></li>
</ul>
<p><strong>输入：</strong></p>
<ul class="simple">
<li><p><strong>x</strong> (Tensor) - Float Tensor。如果use_past为False或者is_first_iteration为True，shape应为[batch_size, seq_length, hidden_size]或者[batch_size * seq_length, hidden_size]。否则，shape应为[batch_size, 1, hidden_size]。</p></li>
<li><p><strong>input_mask</strong> (Tensor) - Float tensor。use_past为False或者is_first_iteration为True时，表示shape为[batch_size, seq_length, seq_length]的注意力掩码。否则，shape应为[batch_size, 1, seq_length]。</p></li>
<li><p><strong>init_reset</strong> (Tensor) - shape为[1]的bool tensor，用于清除增量预测中使用的past key参数和past value参数。仅当use_past为True时有效。默认值为True。</p></li>
<li><p><strong>batch_valid_length</strong> (Tensor) - shape为[batch_size]的Int32 tensor，表示过去所计算的索引。当use_past为True时，它用于增量预测。默认值为None。</p></li>
</ul>
<p><strong>输出：</strong></p>
<p>Tuple，表示一个包含(<cite>output</cite>, <cite>layer_present</cite>)的元组。</p>
<ul class="simple">
<li><p><strong>output</strong> (Tensor) - use_past为False或is_first_iteration为True时，表示shape为(batch_size, seq_length, hidden_size)或(batch_size * seq_length, hidden_size)的层输出的float tensor。否则，shape将为(batch_size, 1, hidden_size)。</p></li>
<li><p><strong>layer_present</strong> (Tuple) - 表示shape为((batch_size, num_heads, size_per_head, seq_length)或(batch_size, num_heads, seq_length, size_per_head))的投影key向量和value向量的Tensor的元组。</p></li>
</ul>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">dtype</span> <span class="k">as</span> <span class="n">mstype</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">TransformerEncoderLayer</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">Tensor</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">TransformerEncoderLayer</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">hidden_size</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">ffn_hidden_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">seq_length</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span>
<span class="gp">... </span>                                <span class="n">num_heads</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">encoder_input_value</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">8</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">encoder_input_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">16</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">output</span><span class="p">,</span> <span class="n">past</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">encoder_input_value</span><span class="p">,</span> <span class="n">encoder_input_mask</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 16, 8)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 4, 16)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 16, 4)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># When use use_past=True, it includes two steps to implement the incremental prediction.</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># Step 1: set is_first_iteration=True, and input the full sequence length&#39;s state.</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">batch_valid_length</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">init_reset</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="kc">True</span><span class="p">],</span> <span class="n">mstype</span><span class="o">.</span><span class="n">bool_</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># Set is_first_iteration=True to generate the full memory states</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">TransformerEncoderLayer</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">hidden_size</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">ffn_hidden_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">seq_length</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span>
<span class="gp">... </span>                                <span class="n">num_heads</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">use_past</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span><span class="o">.</span><span class="n">add_flags_recursive</span><span class="p">(</span><span class="n">is_first_iteration</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">hidden</span><span class="p">,</span> <span class="n">past</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">encoder_input_value</span><span class="p">,</span> <span class="n">encoder_input_mask</span><span class="p">,</span> <span class="n">init_reset</span><span class="p">,</span> <span class="n">batch_valid_length</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">hidden</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 16, 8)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 4, 16)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 16, 4)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">encoder_input_value</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">8</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">encoder_input_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">16</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">init_reset</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">([</span><span class="kc">False</span><span class="p">],</span> <span class="n">mstype</span><span class="o">.</span><span class="n">bool_</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># Step 2: set is_first_iteration=False, and pass the single word to run the prediction rather than the full</span>
<span class="gp">&gt;&gt;&gt; </span><span class="c1"># sequence.</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span><span class="o">.</span><span class="n">add_flags_recursive</span><span class="p">(</span><span class="n">is_first_iteration</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">hidden</span><span class="p">,</span> <span class="n">past</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">encoder_input_value</span><span class="p">,</span> <span class="n">encoder_input_mask</span><span class="p">,</span> <span class="n">init_reset</span><span class="p">,</span> <span class="n">batch_valid_length</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">hidden</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 1, 8)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 4, 16)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 16, 4)</span>
</pre></div>
</div>
</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.TransformerDecoderLayer">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">TransformerDecoderLayer</code><span class="sig-paren">(</span><em class="sig-param">hidden_size</em>, <em class="sig-param">ffn_hidden_size</em>, <em class="sig-param">num_heads</em>, <em class="sig-param">batch_size</em>, <em class="sig-param">src_seq_length</em>, <em class="sig-param">tgt_seq_length</em>, <em class="sig-param">attention_dropout_rate=0.1</em>, <em class="sig-param">hidden_dropout_rate=0.1</em>, <em class="sig-param">post_layernorm_residual=False</em>, <em class="sig-param">use_past=False</em>, <em class="sig-param">layernorm_compute_type=mstype.float32</em>, <em class="sig-param">softmax_compute_type=mstype.float32</em>, <em class="sig-param">param_init_type=mstype.float32</em>, <em class="sig-param">hidden_act=&quot;gelu&quot;</em>, <em class="sig-param">moe_config=default_moe_config</em>, <em class="sig-param">parallel_config=default_dpmp_config</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.TransformerDecoderLayer" title="Permalink to this definition">¶</a></dt>
<dd><p>Transformer的解码器层。Transformer的解码器层上的单层的实现，包括自注意力层、交叉注意力层和前馈层。当encoder_output为None时，交叉注意力将无效。</p>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>batch_size</strong> (int) - 表示输入Tensor的批次大小。</p></li>
<li><p><strong>hidden_size</strong> (int) - 表示输入的隐藏大小。</p></li>
<li><p><strong>src_seq_length</strong> (int) - 表示输入源序列长度。</p></li>
<li><p><strong>tgt_seq_length</strong> (int) - 表示输入目标序列长度。</p></li>
<li><p><strong>ffn_hidden_size</strong> (int) - 表示前馈层中bottleneck的隐藏大小。</p></li>
<li><p><strong>num_heads</strong> (int) - 表示注意力头的数量。</p></li>
<li><p><strong>hidden_dropout_rate</strong> (float) - 表示作用在隐藏层输出的丢弃率。默认值：0.1</p></li>
<li><p><strong>attention_dropout_rate</strong> (float) - 表示注意力score的丢弃率。默认值：0.1</p></li>
<li><p><strong>post_layernorm_residual</strong> (bool) - 表示是否在LayerNorm之前使用残差，即是否选择残差为Post-LayerNorm或者Pre-LayerNorm。默认值：False</p></li>
<li><p><strong>hidden_act</strong> (str) - 表示内部前馈层的激活函数。其值可为’relu’、’relu6’、’tanh’、’gelu’、’fast_gelu’、’elu’、’sigmoid’、’prelu’、’leakyrelu’、’hswish’、’hsigmoid’、’logsigmoid’等等。默认值：gelu。</p></li>
<li><p><strong>layernorm_compute_type</strong> (dtype.Number) - 表示LayerNorm的计算类型。其值应为dtype.float32或dtype.float16。默认值为dtype.float32。</p></li>
<li><p><strong>softmax_compute_type</strong> (dtype.Number) - 表示注意力中softmax的计算类型。其值应为dtype.float32或dtype.float16。默认值为mstype.float32。</p></li>
<li><p><strong>param_init_type</strong> (dtype.Number) - 表示模块的参数初始化类型。其值应为dtype.float32或dtype.float16。默认值为dtype.float32。</p></li>
<li><p><strong>use_past</strong> (bool) - 使用过去状态进行计算，用于增量预测。默认值：False</p></li>
<li><p><strong>moe_config</strong> (MoEConfig) - 表示MoE (Mixture of Expert)的配置。</p></li>
<li><p><strong>parallel_config</strong> (OpParallelConfig) - 表示并行配置。默认值为 <cite>default_dpmp_config</cite> ，表示一个带有默认参数的 <cite>OpParallelConfig</cite> 实例。</p></li>
</ul>
<p><strong>输入：</strong></p>
<ul class="simple">
<li><p><strong>hidden_stats</strong> (Tensor) - shape为[batch_size, tgt_seq_length, hidden_size]或[batch_size * tgt_seq_length, hidden_size]的输入tensor。</p></li>
<li><p><strong>decoder_mask</strong> (Tensor) - shape为[batch_size, tgt_seq_length, tgt_seq_length]的解码器的注意力掩码。</p></li>
<li><p><strong>encoder_output</strong> (Tensor) - shape为[batch_size, seq_length, hidden_size]或[batch_size * seq_length, hidden_size]的编码器的输出。注：当网络位于最外层时，此参数不能通过None传递。默认值为None。</p></li>
<li><p><strong>memory_mask</strong> (Tensor) - shape为[batch, tgt_seq_length, src_seq_length]的交叉注意力的memory掩码，其中tgt_seq_length表示解码器的长度。注：当网络位于最外层时，此参数不能通过None传递。默认值为None。</p></li>
<li><p><strong>init_reset</strong> (Tensor) - shape为[1]的bool tensor，用于清除增量预测中使用的past key参数和past value参数。仅当use_past为True时有效。默认值为True。</p></li>
<li><p><strong>batch_valid_length</strong> (Tensor) - shape为[batch_size]的Int32 tensor，表示过去所计算的索引。当use_past为True时，它用于增量预测。默认值为None。</p></li>
</ul>
<p><strong>输出：</strong></p>
<p>Tuple，表示一个包含(<cite>output</cite>, <cite>layer_present</cite>)的元组。</p>
<ul class="simple">
<li><p><strong>output</strong> (Tensor) - 此层的输出logit。shape为[batch, seq_length, hidden_size]或[batch * seq_length, hidden_size]。</p></li>
<li><p><strong>layer_present</strong> (Tensor) - 元组，其中每个元组都是shape为((batch_size, num_heads, size_per_head, tgt_seq_length)或(batch_size, num_heads, tgt_seq_length, size_per_head)的自注意力中的投影key向量和value向量的tensor，或者是shape为(batch_size, num_heads, size_per_head, src_seq_length)或(batch_size, num_heads, src_seq_length, size_per_head))的交叉注意力中的投影key向量和value向量的tensor。</p></li>
</ul>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">dtype</span> <span class="k">as</span> <span class="n">mstype</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">TransformerDecoderLayer</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">Tensor</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">TransformerDecoderLayer</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">hidden_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">ffn_hidden_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
<span class="gp">... </span>                                <span class="n">src_seq_length</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">tgt_seq_length</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">encoder_input_value</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">64</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">decoder_input_value</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">64</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">decoder_input_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">memory_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">output</span><span class="p">,</span> <span class="n">past</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">decoder_input_value</span><span class="p">,</span> <span class="n">decoder_input_mask</span><span class="p">,</span> <span class="n">encoder_input_value</span><span class="p">,</span> <span class="n">memory_mask</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 10, 64)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 32, 10)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 10, 32)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 32, 20)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">past</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 20, 32)</span>
</pre></div>
</div>
</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.Transformer">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">Transformer</code><span class="sig-paren">(</span><em class="sig-param">hidden_size</em>, <em class="sig-param">batch_size</em>, <em class="sig-param">ffn_hidden_size</em>, <em class="sig-param">src_seq_length</em>, <em class="sig-param">tgt_seq_length</em>, <em class="sig-param">encoder_layers=3</em>, <em class="sig-param">decoder_layers=3</em>, <em class="sig-param">num_heads=2</em>, <em class="sig-param">attention_dropout_rate=0.1</em>, <em class="sig-param">hidden_dropout_rate=0.1</em>, <em class="sig-param">hidden_act=&quot;gelu&quot;</em>, <em class="sig-param">post_layernorm_residual=False</em>, <em class="sig-param">layernorm_compute_type=mstype.float32</em>, <em class="sig-param">softmax_compute_type=mstype.float32</em>, <em class="sig-param">param_init_type=mstype.float32</em>, <em class="sig-param">lambda_func=None</em>, <em class="sig-param">use_past=False</em>, <em class="sig-param">moe_config=default_moe_config</em>, <em class="sig-param">parallel_config=default_transformer_config</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.Transformer" title="Permalink to this definition">¶</a></dt>
<dd><p>Transformer模块，包括编码器和解码器。与原始的实现方式的区别在于该模块在执行层归一化之前使用了残差加法。默认的激活层为 <cite>gelu</cite> 。
详细信息可参考 <a class="reference external" href="https://arxiv.org/pdf/1706.03762v5.pdf">Attention Is All You Need</a> 。</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>这是一个实验接口，可能会被更改或者删除。</p>
</div>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>batch_size</strong> (int) - 表示输入的批次大小。</p></li>
<li><p><strong>encoder_layers</strong> (int) - 表示 <cite>TransformerEncoderLayer</cite> 的层数。</p></li>
<li><p><strong>decoder_layers</strong> (int) - 表示 <cite>TransformerDecoderLayer</cite> 的层数。</p></li>
<li><p><strong>hidden_size</strong> (int) - 表示输入向量的大小。</p></li>
<li><p><strong>ffn_hidden_size</strong> (int) - 表示前馈层中bottleneck的隐藏大小。</p></li>
<li><p><strong>src_seq_length</strong> (int) - 表示编码器的输入Tensor的seq_length。</p></li>
<li><p><strong>tgt_seq_length</strong> (int) - 表示解码器的输入Tensor的seq_length。</p></li>
<li><p><strong>num_heads</strong> (int) - 表示注意力头的数量。默认值：2</p></li>
<li><p><strong>hidden_dropout_rate</strong> (float) - 表示作用在隐藏层输出的丢弃率。默认值：0.1</p></li>
<li><p><strong>attention_dropout_rate</strong> (float) - 表示注意力score的丢弃率。默认值：0.1</p></li>
<li><p><strong>post_layernorm_residual</strong> (bool) - 表示是否在LayerNorm之前使用残差，即是否选择残差为Post-LayerNorm或者Pre-LayerNorm。默认值：False</p></li>
<li><p><strong>layernorm_compute_type</strong> (dtype.Number) - 表示LayerNorm的计算类型。其值应为dtype.float32或dtype.float16。默认值为dtype.float32。</p></li>
<li><p><strong>softmax_compute_type</strong> (dtype.Number) - 表示注意力机制中softmax的计算类型。其值应为dtype.float32或dtype.float16。默认值为mstype.float32。</p></li>
<li><p><strong>param_init_type</strong> (dtype.Number) - 表示模块的参数初始化类型。其值应为dtype.float32或dtype.float16。默认值为dtype.float32。</p></li>
<li><p><strong>hidden_act</strong> (str) - 表示前馈层的激活行为。其值可为’relu’、’relu6’、’tanh’、’gelu’、’fast_gelu’、’elu’、’sigmoid’、’prelu’、’leakyrelu’、’hswish’、’hsigmoid’、’logsigmoid’等等。默认值：gelu。</p></li>
<li><p><strong>moe_config</strong> (MoEConfig) - 表示MoE (Mixture of Expert)的配置。</p></li>
<li><p><strong>lambda_func</strong> - 表示设置融合索引、pipeline阶段和重计算属性的函数。如果用户想确定pipeline阶段和梯度融合，用户可以传递一个接受 <cite>network</cite> 、 <cite>layer_id</cite> 、 <cite>offset</cite> 、 <cite>parallel_config</cite> 和 <cite>layers</cite> 的函数。 <cite>network(Cell)</cite> 表示transformer块， <cite>layer_id(int)</cite> 表示当前模块的层索引，从零开始计数， <cite>offset(int)</cite> 表示如果网络中还有其他模块，则layer_id需要一个偏移。pipeline的默认设置为： <cite>(layer_id + offset) // ((encoder_layers + decoder_layers) / pipeline_stage)</cite> 。</p></li>
<li><p><strong>parallel_config</strong> (TransformerOpParallelConfig) - 表示并行配置。默认值为 <cite>default_transformer_config</cite> ，表示带有默认参数的 <cite>TransformerOpParallelConfig</cite> 实例。</p></li>
</ul>
<p><strong>输入：</strong></p>
<ul class="simple">
<li><p><strong>encoder_inputs</strong> (Tensor) - shape为[batch_size, seq_length, hidden_size]或[batch_size * seq_length, hidden_size]的输入Tensor。</p></li>
<li><p><strong>encoder_masks</strong> (Tensor) - shape为[batch_size, seq_length, seq_length]的编码器的注意力掩码。</p></li>
<li><p><strong>decoder_inputs</strong> (Tensor) - shape为[batch_size, seq_length, hidden_size]或[batch_size * seq_length, hidden_size]的解码器的输入Tensor。如果解码器层数为0，则此值应为None。</p></li>
<li><p><strong>decoder_masks</strong> (Tensor) - shape为[batch_size, seq_length, seq_length]的解码器的注意力掩码。</p></li>
<li><p><strong>memory_mask</strong> (Tensor) - shape为[batch, tgt_seq_length,  src_seq_length]的交叉注意力的memory掩码，其中tgt_seq_length表示解码器的长度。如果解码器层为0，则shape为[batch_size, seq_length, hidden_size]的编码器的输出应为None。</p></li>
<li><p><strong>init_reset</strong> (Tensor) - shape为[1]的bool tensor，用于清除增量预测中使用的past key参数和past value参数。仅当use_past为True时有效。默认值为True。</p></li>
<li><p><strong>batch_valid_length</strong> (Tensor) - shape为[batch_size]的Int32 tensor，表示过去所计算的索引。当use_past为True时，它用于增量预测。默认值为None。</p></li>
</ul>
<p><strong>输出：</strong></p>
<p>Tuple，表示包含(<cite>output</cite>, <cite>encoder_layer_present</cite>, <cite>decoder_layer_present</cite>)的元组。</p>
<ul class="simple">
<li><p><strong>output</strong> (Tensor) - 如果只有编码器，则表示编码器层的输出logit。shape为[batch, src_seq_length, hidden_size] or [batch * src_seq_length, hidden_size]。如果有编码器和解码器，则输出来自于解码器层。shape为[batch, tgt_seq_length, hidden_size]或[batch * tgt_seq_length, hidden_size]。</p></li>
<li><p><strong>encoder_layer_present</strong> (Tuple) - 大小为num_layers的元组，其中每个元组都是shape为((batch_size, num_heads, size_per_head, src_seq_length)或(batch_size, num_heads, src_seq_length, size_per_head))的自注意力中的投影key向量和value向量的tensor。</p></li>
<li><p><strong>decoder_layer_present</strong> (Tuple) - 大小为num_layers的元组，其中每个元组都是shape为((batch_size, num_heads, size_per_head, tgt_seq_length)或(batch_size, num_heads, tgt_seq_length, size_per_head))的self attention中的投影key向量和value向量的tensor，或者是shape为(batch_size, num_heads, size_per_head, src_seq_length)或(batch_size, num_heads, src_seq_length, size_per_head))的交叉注意力中的投影key向量和value向量的tensor。如果未设置解码器，返回值将为None。</p></li>
</ul>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">dtype</span> <span class="k">as</span> <span class="n">mstype</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">Transformer</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">Tensor</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">Transformer</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">encoder_layers</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">decoder_layers</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">hidden_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">ffn_hidden_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span>
<span class="gp">... </span>        <span class="n">src_seq_length</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">tgt_seq_length</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">encoder_input_value</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">64</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">encoder_input_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">20</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">decoder_input_value</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">64</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">decoder_input_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">memory_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">output</span><span class="p">,</span> <span class="n">en_past</span><span class="p">,</span> <span class="n">de_past</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">encoder_input_value</span><span class="p">,</span> <span class="n">encoder_input_mask</span><span class="p">,</span> <span class="n">decoder_input_value</span><span class="p">,</span>
<span class="gp">... </span>                                 <span class="n">decoder_input_mask</span><span class="p">,</span> <span class="n">memory_mask</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 10, 64)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">en_past</span><span class="p">))</span>
<span class="go">1</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">de_past</span><span class="p">))</span>
<span class="go">2</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">en_past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 32, 20)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">en_past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 20, 32)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">de_past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 32, 10)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">de_past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 10, 32)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">de_past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 32, 20)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">de_past</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">3</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 2, 20, 32)</span>
</pre></div>
</div>
</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.TransformerOpParallelConfig">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">TransformerOpParallelConfig</code><span class="sig-paren">(</span><em class="sig-param">data_parallel=1</em>, <em class="sig-param">model_parallel=1</em>, <em class="sig-param">expert_parallel=1</em>, <em class="sig-param">pipeline_stage=1</em>, <em class="sig-param">micro_batch_num=1</em>, <em class="sig-param">recompute=default_transformer_recompute_config</em>, <em class="sig-param">optimizer_shard=False</em>, <em class="sig-param">gradient_aggregation_group=4</em>, <em class="sig-param">vocab_emb_dp=True</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.TransformerOpParallelConfig" title="Permalink to this definition">¶</a></dt>
<dd><p>用于设置数据并行、模型并行等等并行配置的TransformerOpParallelConfig。</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>除recompute参数外，当用户未将auto_parallel_context设为 <cite>SEMI_AUTO_PARALLEL</cite> 或 <cite>AUTO_PARALLEL</cite> 时，其他参数将无效。
在训练时，micro_batch_num的值必须大于或等于pipeline_stage的值。data_parallel * model_parallel * pipeline_stage的值必须等于或小于总设备的数量。设置pipeline_stage和optimizer_shard时，其配置将覆盖auto_parallel_context的配置。</p>
</div>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>data_parallel</strong> (int) - 表示数据并行数。默认值：1。</p></li>
<li><p><strong>model_parallel</strong> (int) - 表示模型并行数。默认值：1。</p></li>
<li><p><strong>expert_parallel</strong> (int) - 表示专家并行数，只有在应用混合专家结构（MoE，Mixture of Experts）时才会生效。默认值：1。</p></li>
<li><p><strong>pipeline_stage</strong> (int) - 表示将Transformer切分成的stage数目。其值应为正数。默认值：1。</p></li>
<li><p><strong>micro_batch_num</strong> (int) - 表示pipeline训练时将batch切分成的micro batch数量。默认值：1。</p></li>
<li><p><strong>optimizer_shard</strong> (bool) - 表示是否使能优化器切分。默认值：False。</p></li>
<li><p><strong>gradient_aggregation_group</strong> (int) - 表示优化器切分的融合组大小。默认值：4。</p></li>
<li><p><strong>recompute</strong> (bool) - 表示是否启用transformer每层的重计算。默认值：False。</p></li>
<li><p><strong>vocab_emb_dp</strong> (bool) - 表示Embedding表是否为数据并行，否则将在查找表中的第0维度根据模型并行度进行切分。默认值：True。</p></li>
</ul>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">TransformerRecomputeConfig</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">recompute_config</span><span class="o">=</span><span class="n">TransformerRecomputeConfig</span><span class="p">(</span><span class="n">recompute</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">parallel_optimizer_comm_recompute</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> \
<span class="gp">... </span>                                            <span class="n">mp_comm_recompute</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">recompute_slice_activation</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">config</span><span class="o">=</span><span class="n">TransformerOpParallelConfig</span><span class="p">(</span><span class="n">data_parallel</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">model_parallel</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">recompute</span><span class="o">=</span><span class="n">recompute_config</span><span class="p">)</span>
</pre></div>
</div>
<dl class="method">
<dt id="mindspore.nn.transformer.TransformerOpParallelConfig.dp_mp_config">
<code class="sig-name descname">dp_mp_config</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.TransformerOpParallelConfig.dp_mp_config" title="Permalink to this definition">¶</a></dt>
<dd><p>获取包含数据并行、模型并行度的DPMPlConfig。</p>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">TransformerOpParallelConfig</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">config</span><span class="o">=</span><span class="n">TransformerOpParallelConfig</span><span class="p">(</span><span class="n">data_parallel</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">model_parallel</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">vocab_emb_dp</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">parallel_config</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">dp_mp_config</span>
</pre></div>
</div>
</dd></dl>

<dl class="method">
<dt id="mindspore.nn.transformer.TransformerOpParallelConfig.embedding_dp_mp_config">
<code class="sig-name descname">embedding_dp_mp_config</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.TransformerOpParallelConfig.embedding_dp_mp_config" title="Permalink to this definition">¶</a></dt>
<dd><p>获取包含数据并行、模型并行和embedding并行度的EmbeddingParallelConfig。</p>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="n">config</span><span class="o">=</span><span class="n">TransformerOpParallelConfig</span><span class="p">(</span><span class="n">data_parallel</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">model_parallel</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">vocab_emb_dp</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">parallel_config</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">embedding_dp_mp_config</span>
</pre></div>
</div>
</dd></dl>

</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.EmbeddingOpParallelConfig">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">EmbeddingOpParallelConfig</code><span class="sig-paren">(</span><em class="sig-param">data_parallel=1</em>, <em class="sig-param">model_parallel=1</em>, <em class="sig-param">vocab_emb_dp=True</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.EmbeddingOpParallelConfig" title="Permalink to this definition">¶</a></dt>
<dd><p><cite>VocabEmbedding</cite> 类中的并行配置。当vocab_emb_dp为True时，设置Embedding查找为数据并行，其中model_parallel参数会被忽略。当vocab_emb_dp为False时，在Embedding表的第0轴进行按model_parallel的大小进行切分。</p>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>data_parallel</strong> (int) - 表示数据并行度。默认值：1。</p></li>
<li><p><strong>model_parallel</strong> (int) - 表示模型并行度。默认值：1。</p></li>
<li><p><strong>vocab_emb_dp</strong> (bool) - 表示模型并行或数据并行中的Shard embedding。默认值：True。</p></li>
</ul>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">EmbeddingOpParallelConfig</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">config</span><span class="o">=</span><span class="n">EmbeddingOpParallelConfig</span><span class="p">(</span><span class="n">data_parallel</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">model_parallel</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">vocab_emb_dp</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
<dl class="method">
<dt id="mindspore.nn.transformer.EmbeddingOpParallelConfig.dp_mp_config">
<code class="sig-name descname">dp_mp_config</code><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.EmbeddingOpParallelConfig.dp_mp_config" title="Permalink to this definition">¶</a></dt>
<dd><p>获取包含有data_parallel和model_parallel属性的DPMPlConfig类。</p>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">EmbeddingOpParallelConfig</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">config</span><span class="o">=</span><span class="n">EmbeddingOpParallelConfig</span><span class="p">(</span><span class="n">data_parallel</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">model_parallel</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">vocab_emb_dp</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">parallel_config</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">dp_mp_config</span>
</pre></div>
</div>
</dd></dl>

</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.CrossEntropyLoss">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">CrossEntropyLoss</code><span class="sig-paren">(</span><em class="sig-param">parallel_config=default_dpmp_config</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.CrossEntropyLoss" title="Permalink to this definition">¶</a></dt>
<dd><p>计算输入和输出之间的交叉熵损失。</p>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>parallel_config</strong> (OpParallelConfig) - 表示并行配置。默认值为 <cite>default_dpmp_config</cite> ，表示一个带有默认参数的 <cite>OpParallelConfig</cite> 实例。</p></li>
</ul>
<p><strong>输入：</strong></p>
<ul class="simple">
<li><p><strong>logits</strong> (Tensor) - shape为(N, C)的Tensor。表示的输出logits。其中N表示任意大小的维度，C表示类别个数。数据类型必须为float16或float32。</p></li>
<li><p><strong>labels</strong> (Tensor) - shape为(N, )的Tensor。表示样本的真实标签，其中每个元素的取值区间为[0,C)。</p></li>
<li><p><strong>input_mask</strong> (Tensor) - shape为(N, )的Tensor。input_mask表示是否有填充输入。1表示有效，0表示无效，其中元素值为0的位置不会计算进损失值。</p></li>
</ul>
<p><strong>输出：</strong></p>
<p>Tensor，表示对应的交叉熵损失。</p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">dtype</span> <span class="k">as</span> <span class="n">mstype</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">Tensor</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">loss</span> <span class="o">=</span> <span class="n">CrossEntropyLoss</span><span class="p">()</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">logits</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">9</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">33</span><span class="p">,</span> <span class="mi">42</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">72</span><span class="p">]]),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">labels_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">input_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">))</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">labels</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">labels_np</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">output</span> <span class="o">=</span> <span class="n">loss</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">input_mask</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(1,)</span>
</pre></div>
</div>
</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.OpParallelConfig">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">OpParallelConfig</code><span class="sig-paren">(</span><em class="sig-param">data_parallel=1</em>, <em class="sig-param">model_parallel=1</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.OpParallelConfig" title="Permalink to this definition">¶</a></dt>
<dd><p>用于设置数据并行和模型并行的OpParallelConfig。</p>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>data_parallel</strong> (int) - 表示数据并行度。默认值：1</p></li>
<li><p><strong>model_parallel</strong> (int) - 表示模型并行度。默认值：1</p></li>
</ul>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">OpParallelConfig</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">config</span><span class="o">=</span><span class="n">OpParallelConfig</span><span class="p">(</span><span class="n">data_parallel</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">model_parallel</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</pre></div>
</div>
</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.FixedSparseAttention">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">FixedSparseAttention</code><span class="sig-paren">(</span><em class="sig-param">batch_size</em>, <em class="sig-param">num_heads</em>, <em class="sig-param">size_per_head</em>, <em class="sig-param">block_size</em>, <em class="sig-param">seq_length=1024</em>, <em class="sig-param">num_different_global_patterns=4</em>, <em class="sig-param">parallel_config=default_dpmp_config</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.FixedSparseAttention" title="Permalink to this definition">¶</a></dt>
<dd><p>固定稀疏注意力层。</p>
<p>此接口实现了Sparse Transformer中使用的稀疏注意力原语。更多详情，请见论文（<a class="reference external" href="https://arxiv.org/abs/1904.10509">https://arxiv.org/abs/1904.10509</a>）。</p>
<p>具体来说，它包括以下内容：</p>
<ol class="arabic simple">
<li><p>正常注意力的更快实现（不计算上三角，并且融合了许多操作）。</p></li>
<li><p>如论文Sparse Transformers所述，“分散”和“固定”注意力的实现。</p></li>
</ol>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>batch_size</strong> (int) - 表示输入batch size的数量。</p></li>
<li><p><strong>num_heads</strong> (int) - 表示注意力头数。</p></li>
<li><p><strong>block_size</strong> (int) - 表示用来确定block size的整数。目前稀疏自注意力的实现基于稀疏块矩阵。此参数定义了稀疏矩阵块的大小。目前仅支持64。</p></li>
<li><p><strong>seq_length</strong> (int) - 表示输入序列的长度。目前只支持1024。</p></li>
<li><p><strong>num_different_global_patterns</strong> (int) - 表示用于确定不同的全局注意力数量。虽然全局注意力由局部的代表性的块决定，
但由于有多个头，所以每个头都可以使用不同的全局代表。目前只支持4。</p></li>
<li><p><strong>size_per_head</strong> (int) - 表示每个注意力头的向量大小。目前仅支持64和128。</p></li>
</ul>
<p><strong>输入：</strong></p>
<ul class="simple">
<li><p><strong>q</strong> (Tensor) - Tensor query (<code class="xref py py-class docutils literal notranslate"><span class="pre">mstype.fp16</span></code> [batch_size, seq_length, hidden_size])：表示上下文的query向量。</p></li>
<li><p><strong>k</strong> (Tensor) - Tensor key (<code class="xref py py-class docutils literal notranslate"><span class="pre">mstype.fp16</span></code> [batch_size, seq_length, hidden_size])：表示上下文的key向量。</p></li>
<li><p><strong>v</strong> (Tensor) - Tensor value (<code class="xref py py-class docutils literal notranslate"><span class="pre">mstype.fp16</span></code> [批次大小, seq_length, hidden_size])：表示上下文的value向量。</p></li>
<li><p><strong>attention_mask</strong> (Tensor) - Float Tensor the mask of (<code class="xref py py-class docutils literal notranslate"><span class="pre">mstype.fp32</span></code> , <code class="xref py py-class docutils literal notranslate"><span class="pre">mstype.fp16</span></code> [batch_size, seq_length, seq_length])：
表示掩码的下三角形矩阵。</p></li>
</ul>
<p><strong>输出：</strong></p>
<p>Tensor，shape为[batch_size, seq_length, hidden_size]。</p>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">dtype</span> <span class="k">as</span> <span class="n">mstype</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">FixedSparseAttention</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore</span> <span class="kn">import</span> <span class="n">Tensor</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">FixedSparseAttention</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
<span class="gp">... </span>                             <span class="n">num_heads</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span>
<span class="gp">... </span>                             <span class="n">size_per_head</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span>
<span class="gp">... </span>                             <span class="n">block_size</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">q</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1024</span><span class="p">,</span> <span class="mi">8</span><span class="o">*</span><span class="mi">64</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">k</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1024</span><span class="p">,</span> <span class="mi">8</span><span class="o">*</span><span class="mi">64</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">v</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1024</span><span class="p">,</span> <span class="mi">8</span><span class="o">*</span><span class="mi">64</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float16</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">attention_mask</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1024</span><span class="p">,</span> <span class="mi">1024</span><span class="p">)),</span> <span class="n">mstype</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="nb">print</span><span class="p">(</span><span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="go">(2, 1024, 512)</span>
</pre></div>
</div>
</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.MoEConfig">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">MoEConfig</code><span class="sig-paren">(</span><em class="sig-param">expert_num=1</em>, <em class="sig-param">capacity_factor=1.1</em>, <em class="sig-param">aux_loss_factor=0.05</em>, <em class="sig-param">num_experts_chosen=1</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.MoEConfig" title="Permalink to this definition">¶</a></dt>
<dd><p>MoE (Mixture of Expert)的配置。</p>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>expert_num</strong> (int) - 表示使用的专家数量。默认值：1。</p></li>
<li><p><strong>capacity_factor</strong> (float) - 表示专家处理的容量关系，其值大于等于1.0。默认值：1.1。</p></li>
<li><p><strong>aux_loss_factor</strong> (float) - 表示负载均衡损失（由路由器产生）的平衡系数。相乘的结果会加到总损失函数中。此系数的值小于1.0。默认值：0.05。</p></li>
<li><p><strong>num_experts_chosen</strong> (int) - 表示每个标识选择的专家数量。默认值：1。</p></li>
</ul>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">MoEConfig</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">moe_config</span> <span class="o">=</span> <span class="n">MoEConfig</span><span class="p">(</span><span class="n">expert_num</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">capacity_factor</span><span class="o">=</span><span class="mf">5.0</span><span class="p">,</span> <span class="n">aux_loss_factor</span><span class="o">=</span><span class="mf">0.05</span><span class="p">,</span> <span class="n">num_experts_chosen</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</pre></div>
</div>
</dd></dl>

<dl class="class">
<dt id="mindspore.nn.transformer.TransformerRecomputeConfig">
<em class="property">class </em><code class="sig-prename descclassname">mindspore.nn.transformer.</code><code class="sig-name descname">TransformerRecomputeConfig</code><span class="sig-paren">(</span><em class="sig-param">recompute=False</em>, <em class="sig-param">parallel_optimizer_comm_recompute=False</em>, <em class="sig-param">mp_comm_recompute=True</em>, <em class="sig-param">recompute_slice_activation=False</em><span class="sig-paren">)</span><a class="headerlink" href="#mindspore.nn.transformer.TransformerRecomputeConfig" title="Permalink to this definition">¶</a></dt>
<dd><p>Transformer的重计算配置接口。</p>
<p><strong>参数：</strong></p>
<ul class="simple">
<li><p><strong>recompute</strong> (bool) - 是否使能重计算。默认值为False。</p></li>
<li><p><strong>parallel_optimizer_comm_recompute</strong> (bool) - 指定由优化器切分产生的AllGather算子是否进行重计算。默认值为False。</p></li>
<li><p><strong>mp_comm_recompute</strong> (bool) - 指定由模型并行成分产生的通信算子是否进行重计算。默认值为True。</p></li>
<li><p><strong>recompute_slice_activation</strong> (bool) - 指定激活层是否切片保存。默认值为False。</p></li>
</ul>
<p><strong>支持平台：</strong></p>
<p><code class="docutils literal notranslate"><span class="pre">Ascend</span></code> <code class="docutils literal notranslate"><span class="pre">GPU</span></code></p>
<p><strong>样例：</strong></p>
<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">mindspore.nn.transformer</span> <span class="kn">import</span> <span class="n">TransformerRecomputeConfig</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">config</span><span class="o">=</span><span class="n">TransformerRecomputeConfig</span><span class="p">(</span><span class="n">recompute</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">parallel_optimizer_comm_recompute</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> \
<span class="gp">... </span>                                  <span class="n">mp_comm_recompute</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">recompute_slice_activation</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
</dd></dl>

</div>


           </div>
           
          </div>
          <footer>
    <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
        <a href="mindspore.numpy.html" class="btn btn-neutral float-right" title="mindspore.numpy" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
        <a href="nn_probability/mindspore.nn.probability.distribution.Uniform.html" class="btn btn-neutral float-left" title="mindspore.nn.probability.distribution.Uniform" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
    </div>

  <hr/>

  <div role="contentinfo">
    <p>
        &#169; Copyright 2021, MindSpore.

    </p>
  </div>
    
    
    
    Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    
    <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    
    provided by <a href="https://readthedocs.org">Read the Docs</a>. 

</footer>
        </div>
      </div>

    </section>

  </div>
  

  <script type="text/javascript">
      jQuery(function () {
          SphinxRtdTheme.Navigation.enable(true);
      });
  </script>

  
  
    
   

</body>
</html>